#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File      :   iRMB.py
@Time      :   2024/02/26 20:16:05
@Author    :   CSDN迪菲赫尔曼 
@Version   :   1.0
@Reference :   https://blog.csdn.net/weixin_43694096
@Desc      :   iRMB (inverted Residual Mobile Block) from EMO, as a plug-in module
"""


import math
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.pad in iRMB.forward
from functools import partial
from einops import rearrange, reduce

from timm.layers.activations import *
from timm.layers import DropPath


__all__ = ["iRMB"]
inplace = True


# SE: squeeze-and-excitation channel attention
class SE(nn.Module):
    def __init__(self, c1, ratio=16):
        super(SE, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.l1 = nn.Linear(c1, c1 // ratio, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c1 // ratio, c1, bias=False)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.l1(y)
        y = self.relu(y)
        y = self.l2(y)
        y = self.sig(y)
        y = y.view(b, c, 1, 1)
        return x * y.expand_as(x)


class LayerNorm2d(nn.Module):
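    """LayerNorm for NCHW feature maps: permute to NHWC, normalize over the
    channel dimension, and permute back."""
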
    def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
        super().__init__()
        self.norm = nn.LayerNorm(normalized_shape, eps, elementwise_affine)

    def forward(self, x):
        x = rearrange(x, "b c h w -> b h w c").contiguous()
        x = self.norm(x)
        x = rearrange(x, "b h w c -> b c h w").contiguous()
        return x


def get_norm(norm_layer="in_1d"):
    eps = 1e-6
    norm_dict = {
        "none": nn.Identity,
        "in_1d": partial(nn.InstanceNorm1d, eps=eps),
        "in_2d": partial(nn.InstanceNorm2d, eps=eps),
        "in_3d": partial(nn.InstanceNorm3d, eps=eps),
        "bn_1d": partial(nn.BatchNorm1d, eps=eps),
        "bn_2d": partial(nn.BatchNorm2d, eps=eps),
        # 'bn_2d': partial(nn.SyncBatchNorm, eps=eps),
        "bn_3d": partial(nn.BatchNorm3d, eps=eps),
        "gn": partial(nn.GroupNorm, eps=eps),
        "ln_1d": partial(nn.LayerNorm, eps=eps),
        "ln_2d": partial(LayerNorm2d, eps=eps),
    }
    return norm_dict[norm_layer]


def get_act(act_layer="relu"):
    act_dict = {
        "none": nn.Identity,
        "sigmoid": Sigmoid,
        "swish": Swish,
        "mish": Mish,
        "hsigmoid": HardSigmoid,
        "hswish": HardSwish,
        "hmish": HardMish,
        "tanh": Tanh,
        "relu": nn.ReLU,
        "relu6": nn.ReLU6,
        "prelu": PReLU,
        "gelu": GELU,
        "silu": nn.SiLU,
    }
    return act_dict[act_layer]


class LayerScale(nn.Module):
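    """Learnable per-channel scaling of a residual branch (layer scale) for
    (B, N, C) token layouts; LayerScale2D below is the NCHW counterpart."""
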
    def __init__(self, dim, init_values=1e-5, inplace=True):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(1, 1, dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class LayerScale2D(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=True):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(1, dim, 1, 1))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class ConvNormAct(nn.Module):
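    """Conv2d followed by normalization and activation, with an optional
    residual connection (when skip=True and dim_in == dim_out) and DropPath."""
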
    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=False,
        skip=False,
        norm_layer="bn_2d",
        act_layer="relu",
        inplace=True,
        drop_path_rate=0.0,
    ):
        super(ConvNormAct, self).__init__()
        self.has_skip = skip and dim_in == dim_out
        # padding keeps the spatial size when stride == 1, also accounting for dilation
        padding = math.ceil(((kernel_size - 1) * dilation + 1 - stride) / 2)
        self.conv = nn.Conv2d(
            dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias
        )
        self.norm = get_norm(norm_layer)(dim_out)
        self.act = get_act(act_layer)(inplace=inplace)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x


# ========== Multi-Scale Patch Embedding, for down-sampling and inductive bias ==========
class MSPatchEmb(nn.Module):
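    """Multi-scale patch embedding: parallel grouped convolutions with
    different dilations whose outputs are averaged."""
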
    def __init__(
        self,
        dim_in,
        emb_dim,
        kernel_size=2,
        c_group=-1,
        stride=1,
        dilations=[1, 2, 3],
        norm_layer="bn_2d",
        act_layer="silu",
    ):
        super().__init__()
        self.dilation_num = len(dilations)
        # resolve c_group before validating it (gcd of dim_in and emb_dim always divides both)
        c_group = math.gcd(dim_in, emb_dim) if c_group == -1 else c_group
        assert dim_in % c_group == 0
        self.convs = nn.ModuleList()
        for i in range(len(dilations)):
            padding = math.ceil(((kernel_size - 1) * dilations[i] + 1 - stride) / 2)
            self.convs.append(
                nn.Sequential(
                    nn.Conv2d(
                        dim_in,
                        emb_dim,
                        kernel_size,
                        stride,
                        padding,
                        dilations[i],
                        groups=c_group,
                    ),
                    get_norm(norm_layer)(emb_dim),
                    get_act(act_layer)(inplace=True),
                )
            )

    def forward(self, x):
        if self.dilation_num == 1:
            x = self.convs[0](x)
        else:
            x = torch.cat(
                [self.convs[i](x).unsqueeze(dim=-1) for i in range(self.dilation_num)],
                dim=-1,
            )
            x = reduce(x, "b c h w n -> b c h w", "mean").contiguous()
        return x


class iRMB(nn.Module):
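    """Inverted Residual Mobile Block (iRMB) from EMO: an inverted-residual
    bottleneck whose expansion stage combines optional window-based multi-head
    self-attention (attn_s) with a depth-wise convolution, followed by SE and
    a 1x1 projection."""
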
    def __init__(
        self,
        dim_in,
        dim_out,
        norm_in=True,
        has_skip=True,
        exp_ratio=1.0,
        norm_layer="bn_2d",
        act_layer="relu",
        v_proj=True,
        dw_ks=3,
        stride=1,
        dilation=1,
        se_ratio=0.0,
        dim_head=64,
        window_size=7,
        attn_s=True,
        qkv_bias=False,
        attn_drop=0.0,
        drop=0.0,
        drop_path=0.0,
        v_group=False,
        attn_pre=False,
    ):
        super().__init__()
        self.norm = get_norm(norm_layer)(dim_in) if norm_in else nn.Identity()
        dim_mid = int(dim_in * exp_ratio)
        self.has_skip = (dim_in == dim_out and stride == 1) and has_skip
        self.attn_s = attn_s
        if self.attn_s:
            assert dim_in % dim_head == 0, "dim should be divisible by num_heads"
            self.dim_head = dim_head
            self.window_size = window_size
            self.num_head = dim_in // dim_head
            self.scale = self.dim_head**-0.5
            self.attn_pre = attn_pre
            self.qk = ConvNormAct(
                dim_in,
                int(dim_in * 2),
                kernel_size=1,
                bias=qkv_bias,
                norm_layer="none",
                act_layer="none",
            )
            self.v = ConvNormAct(
                dim_in,
                dim_mid,
                kernel_size=1,
                groups=self.num_head if v_group else 1,
                bias=qkv_bias,
                norm_layer="none",
                act_layer=act_layer,
                inplace=inplace,
            )
            self.attn_drop = nn.Dropout(attn_drop)
        else:
            if v_proj:
                self.v = ConvNormAct(
                    dim_in,
                    dim_mid,
                    kernel_size=1,
                    bias=qkv_bias,
                    norm_layer="none",
                    act_layer=act_layer,
                    inplace=inplace,
                )
            else:
                self.v = nn.Identity()
        self.conv_local = ConvNormAct(
            dim_mid,
            dim_mid,
            kernel_size=dw_ks,
            stride=stride,
            dilation=dilation,
            groups=dim_mid,
            norm_layer="bn_2d",
            act_layer="silu",
            inplace=inplace,
        )
        # the custom SE above expects a reduction divisor rather than timm's
        # fractional rd_ratio, so convert se_ratio (e.g. 0.25) into one (e.g. 4)
        self.se = (
            SE(dim_mid, ratio=max(1, round(1 / se_ratio)))
            if se_ratio > 0.0
            else nn.Identity()
        )

        self.proj_drop = nn.Dropout(drop)
        self.proj = ConvNormAct(
            dim_mid,
            dim_out,
            kernel_size=1,
            norm_layer="none",
            act_layer="none",
            inplace=inplace,
        )
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.norm(x)
        B, C, H, W = x.shape
        if self.attn_s:
            # padding
            if self.window_size <= 0:
                window_size_W, window_size_H = W, H
            else:
                window_size_W, window_size_H = self.window_size, self.window_size
            pad_l, pad_t = 0, 0
            pad_r = (window_size_W - W % window_size_W) % window_size_W
            pad_b = (window_size_H - H % window_size_H) % window_size_H
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b, 0, 0))
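            # fold the H x W grid into n1 * n2 groups along the batch dimension
            # so that attention is computed independently within each group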
            n1, n2 = (H + pad_b) // window_size_H, (W + pad_r) // window_size_W
            x = rearrange(
                x, "b c (h1 n1) (w1 n2) -> (b n1 n2) c h1 w1", n1=n1, n2=n2
            ).contiguous()
            # attention
            b, c, h, w = x.shape
            qk = self.qk(x)
            qk = rearrange(
                qk,
                "b (qk heads dim_head) h w -> qk b heads (h w) dim_head",
                qk=2,
                heads=self.num_head,
                dim_head=self.dim_head,
            ).contiguous()
            q, k = qk[0], qk[1]
            attn_spa = (q @ k.transpose(-2, -1)) * self.scale
            attn_spa = attn_spa.softmax(dim=-1)
            attn_spa = self.attn_drop(attn_spa)
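            # attn_pre: attend over the un-expanded features first, then expand
            # them with the value projection; otherwise expand first and attend
            # over the expanded features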
            if self.attn_pre:
                x = rearrange(
                    x,
                    "b (heads dim_head) h w -> b heads (h w) dim_head",
                    heads=self.num_head,
                ).contiguous()
                x_spa = attn_spa @ x
                x_spa = rearrange(
                    x_spa,
                    "b heads (h w) dim_head -> b (heads dim_head) h w",
                    heads=self.num_head,
                    h=h,
                    w=w,
                ).contiguous()
                x_spa = self.v(x_spa)
            else:
                v = self.v(x)
                v = rearrange(
                    v,
                    "b (heads dim_head) h w -> b heads (h w) dim_head",
                    heads=self.num_head,
                ).contiguous()
                x_spa = attn_spa @ v
                x_spa = rearrange(
                    x_spa,
                    "b heads (h w) dim_head -> b (heads dim_head) h w",
                    heads=self.num_head,
                    h=h,
                    w=w,
                ).contiguous()
            # unpadding
            x = rearrange(
                x_spa, "(b n1 n2) c h1 w1 -> b c (h1 n1) (w1 n2)", n1=n1, n2=n2
            ).contiguous()
            if pad_r > 0 or pad_b > 0:
                x = x[:, :, :H, :W].contiguous()
        else:
            x = self.v(x)

        x = (
            x + self.se(self.conv_local(x))
            if self.has_skip
            else self.se(self.conv_local(x))
        )

        x = self.proj_drop(x)
        x = self.proj(x)

        x = (shortcut + self.drop_path(x)) if self.has_skip else x
        return x
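

# Minimal smoke test (not part of the original module); it assumes `timm` and
# `einops` are installed and only checks that a stride-1 iRMB preserves the
# input shape. The sizes below are illustrative, not from the EMO paper.
if __name__ == "__main__":
    x = torch.randn(1, 64, 56, 56)
    # with attn_s=True, dim_in must be divisible by dim_head (64 / 16 -> 4 heads)
    block = iRMB(dim_in=64, dim_out=64, dim_head=16, window_size=7, attn_s=True)
    block.eval()
    with torch.no_grad():
        y = block(x)
    print(y.shape)  # expected: torch.Size([1, 64, 56, 56])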
